import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import init

# TODO: Extend to support mask tensors with more than 2 dimensions.
def random_restrict_fanin(mask: Tensor, fan_in: int) -> Tensor:
    """Fill a 2-D mask in place so each row has ``fan_in`` randomly placed ones.

    Every entry of ``mask`` is reset to zero. Then, for each row, ``fan_in``
    distinct column indices are drawn uniformly at random (via
    ``torch.randperm``) and set to one. Rows are dim 0 and columns are dim 1,
    matching how ``torch.nn.init`` treats a 2-D weight (fan-in is
    ``mask.shape[1]``).

    Args:
        mask: 2-D tensor to overwrite. It is modified in place.
        fan_in: Number of ones per row. Values larger than ``mask.shape[1]``
            select every column of the row.

    Returns:
        ``mask`` itself, after modification.

    Raises:
        ValueError: If ``mask`` is not 2-D or ``fan_in`` is negative. Both
            checks happen before ``mask`` is modified.
    """
    if mask.dim() != 2:
        # Raise rather than ``assert`` so the check survives ``python -O``,
        # and do it before zeroing so a rejected mask is left untouched.
        raise ValueError(f"Unsupported mask shape, specified: {mask.shape}")
    if fan_in < 0:
        # A negative slice bound would silently keep all but |fan_in| columns.
        raise ValueError(f"fan_in must be non-negative, got {fan_in}")

    num_vectors, vector_size = mask.shape
    # Mirror torch.nn.init: the mask may be a leaf that requires grad
    # (e.g. an nn.Parameter), and in-place writes would otherwise fail.
    with torch.no_grad():
        mask.zero_()
        for row in range(num_vectors):
            # One randperm per row, in row order, as before, so seeded
            # results are reproducible across versions of this function.
            cols = torch.randperm(vector_size)[:fan_in]
            mask[row, cols] = 1
    return mask

